import click
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import dgl

from logzero import logger
from pathlib import Path
from ruamel.yaml import YAML
from torch.utils.data import DataLoader
from dgl.data.utils import Subset
from sklearn.metrics import mean_absolute_error
from qm9 import QM9
from modules.initializers import GlorotOrthogonal
from modules.dimenet import DimeNet
from modules.dimenet_pp import DimeNetPP

def split_dataset(dataset, num_train, num_valid, shuffle=False, random_state=None):
    """Split dataset into training, validation and test set.

    Parameters
    ----------
    dataset
        We assume that ``len(dataset)`` gives the number of datapoints and ``dataset[i]``
        gives the ith datapoint.
    num_train : int
        Number of training datapoints.
    num_valid : int
        Number of validation datapoints.
    shuffle : bool, optional
        By default we perform a consecutive split of the dataset. If True,
        we will first randomly shuffle the dataset.
    random_state : None, int or array_like, optional
        Random seed used to initialize the pseudo-random number generator.
        This can be any integer between 0 and 2^32 - 1 inclusive, an array
        (or other sequence) of such integers, or None (the default value).
        If seed is None, then RandomState will try to read data from /dev/urandom
        (or the Windows analogue) if available or seed from the clock otherwise.

    Returns
    -------
    list of length 3
        Subsets for training, validation and test.

    Raises
    ------
    ValueError
        If ``num_train + num_valid`` leaves no datapoints for the test set.
    """
    from itertools import accumulate
    num_data = len(dataset)
    # Validate explicitly instead of `assert`: asserts are stripped when Python
    # runs with -O, which would silently allow an empty or negative test split.
    if num_train + num_valid >= num_data:
        raise ValueError(
            f'num_train ({num_train}) + num_valid ({num_valid}) must be smaller '
            f'than the dataset size ({num_data}) to leave a non-empty test set.')
    lengths = [num_train, num_valid, num_data - num_train - num_valid]
    if shuffle:
        indices = np.random.RandomState(seed=random_state).permutation(num_data)
    else:
        indices = np.arange(num_data)
    # accumulate(lengths) yields the exclusive end offset of each split;
    # slicing [offset - length:offset] recovers the matching index range.
    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(accumulate(lengths), lengths)]

@torch.no_grad()
def ema(ema_model, model, decay):
    """Blend ``model``'s weights into ``ema_model`` as an exponential moving average.

    Each entry of ``ema_model``'s state dict is updated in place to
    ``decay * old + (1 - decay) * current``; gradients are never tracked.
    """
    source_state = model.state_dict()
    for name, shadow in ema_model.state_dict().items():
        current = source_state[name].detach()
        # copy_ (rather than in-place arithmetic) preserves the original
        # cast-on-copy behavior for non-float entries of the state dict.
        blended = shadow * decay + (1. - decay) * current
        shadow.copy_(blended)

def edge_init(edges):
    """DGL edge UDF: derive bond length and orientation from node coordinates.

    Reads the 3-D position ``'R'`` of each edge's endpoints and returns
    ``'d'`` (Euclidean distance, clamped non-negative under the sqrt) and
    ``'o'`` (the source-minus-destination displacement vector).
    """
    displacement = edges.src['R'] - edges.dst['R']
    squared_dist = (displacement ** 2).sum(-1)
    # relu guards against tiny negative values from floating-point error
    dist = F.relu(squared_dist).sqrt()
    # d: bond length, o: bond orientation
    return {'d': dist, 'o': displacement}

def _collate_fn(batch):
    """Collate (graph, line_graph, label) triples into batched DGL graphs.

    Returns the batched molecular graph, the batched line graph, and a
    float32 tensor of labels, ready for a DataLoader to hand to the model.
    """
    graphs, line_graphs, labels = zip(*batch)
    batched_graph = dgl.batch(list(graphs))
    batched_line_graph = dgl.batch(list(line_graphs))
    label_tensor = torch.tensor(list(labels), dtype=torch.float32)
    return batched_graph, batched_line_graph, label_tensor

def train(device, model, opt, loss_fn, train_loader):
    """Run one optimization epoch over ``train_loader``.

    Moves each (graph, line_graph, labels) batch to ``device``, computes the
    loss against labels reshaped to a column vector, and takes one optimizer
    step per batch. Returns the sample-weighted mean loss of the epoch.
    """
    model.train()
    total_loss = 0.0
    total_samples = 0

    for g, l_g, labels in train_loader:
        g, l_g, labels = g.to(device), l_g.to(device), labels.to(device)
        preds = model(g, l_g)
        batch_loss = loss_fn(preds, labels.view([-1, 1]))

        opt.zero_grad()
        batch_loss.backward()
        opt.step()

        # Weight by batch size so the epoch average is per-sample, not per-batch.
        batch_size = len(labels)
        total_loss += batch_loss.item() * batch_size
        total_samples += batch_size

    return total_loss / total_samples

@torch.no_grad()
def evaluate(device, model, valid_loader):
    """Run inference over ``valid_loader`` without gradient tracking.

    Returns a pair of 1-D numpy arrays: model predictions (flattened to one
    value per sample) and the corresponding ground-truth labels, in loader
    order.
    """
    model.eval()
    preds, targets = [], []

    for g, l_g, labels in valid_loader:
        logits = model(g.to(device), l_g.to(device))
        preds.extend(logits.view(-1,).cpu().numpy())
        targets.extend(labels)

    return np.array(preds), np.array(targets)

@click.command()
@click.option('-m', '--model-cnf', type=click.Path(exists=True), help='Path of model config yaml.')
def main(model_cnf):
    """Train (or test a pretrained) DimeNet / DimeNet++ model on QM9.

    All settings come from the YAML file given via ``--model-cnf``, which must
    provide the keys ``name``, ``model``, ``train`` and ``pretrain``.
    """
    # Parse the YAML config into the name / model / train / pretrain sections.
    yaml = YAML(typ='safe')
    model_cnf = yaml.load(Path(model_cnf))
    model_name, model_params, train_params, pretrain_params = model_cnf['name'], model_cnf['model'], model_cnf['train'], model_cnf['pretrain']
    logger.info(f'Model name: {model_name}')
    logger.info(f'Model params: {model_params}')
    logger.info(f'Train params: {train_params}')

    # Choose the output-layer initializer by target property.
    # NOTE(review): this membership test treats 'targets' as a single string,
    # while len(...) and [0] below suggest a sequence — confirm how the YAML
    # stores 'targets' (e.g. 'mu' vs ['mu']).
    if model_params['targets'] in ['mu', 'homo', 'lumo', 'gap', 'zpve']:
        model_params['output_init'] = nn.init.zeros_
    else:
        # 'GlorotOrthogonal' for alpha, R2, U0, U, H, G, and Cv
        model_params['output_init'] = GlorotOrthogonal

    logger.info('Loading Data Set')
    # edge_init attaches bond length 'd' and orientation 'o' to every edge.
    dataset = QM9(label_keys=model_params['targets'], edge_funcs=[edge_init])

    # data split
    train_data, valid_data, test_data = split_dataset(dataset,
                                                      num_train=train_params['num_train'],
                                                      num_valid=train_params['num_valid'],
                                                      shuffle=True,
                                                      random_state=train_params['data_seed'])
    logger.info(f'Size of Training Set: {len(train_data)}')
    logger.info(f'Size of Validation Set: {len(valid_data)}')
    logger.info(f'Size of Test Set: {len(test_data)}')

    # data loader; _collate_fn batches the graph/line-graph/label triples.
    train_loader = DataLoader(train_data,
                              batch_size=train_params['batch_size'],
                              shuffle=True,
                              collate_fn=_collate_fn,
                              num_workers=train_params['num_workers'])

    valid_loader = DataLoader(valid_data,
                              batch_size=train_params['batch_size'],
                              shuffle=False,
                              collate_fn=_collate_fn,
                              num_workers=train_params['num_workers'])

    test_loader = DataLoader(test_data,
                             batch_size=train_params['batch_size'],
                             shuffle=False,
                             collate_fn=_collate_fn,
                             num_workers=train_params['num_workers'])

    # check cuda; fall back to CPU when the configured GPU is unavailable.
    gpu = train_params['gpu']
    device = f'cuda:{gpu}' if gpu >= 0 and torch.cuda.is_available() else 'cpu'

    # model initialization
    logger.info('Loading Model')
    if model_name == 'dimenet':
        model = DimeNet(emb_size=model_params['emb_size'],
                        num_blocks=model_params['num_blocks'],
                        num_bilinear=model_params['num_bilinear'],
                        num_spherical=model_params['num_spherical'],
                        num_radial=model_params['num_radial'],
                        cutoff=model_params['cutoff'],
                        envelope_exponent=model_params['envelope_exponent'],
                        num_before_skip=model_params['num_before_skip'],
                        num_after_skip=model_params['num_after_skip'],
                        num_dense_output=model_params['num_dense_output'],
                        num_targets=len(model_params['targets']),
                        output_init=model_params['output_init']).to(device)
    elif model_name == 'dimenet++':
        model = DimeNetPP(emb_size=model_params['emb_size'],
                          out_emb_size=model_params['out_emb_size'],
                          int_emb_size=model_params['int_emb_size'],
                          basis_emb_size=model_params['basis_emb_size'],
                          num_blocks=model_params['num_blocks'],
                          num_spherical=model_params['num_spherical'],
                          num_radial=model_params['num_radial'],
                          cutoff=model_params['cutoff'],
                          envelope_exponent=model_params['envelope_exponent'],
                          num_before_skip=model_params['num_before_skip'],
                          num_after_skip=model_params['num_after_skip'],
                          num_dense_output=model_params['num_dense_output'],
                          num_targets=len(model_params['targets']),
                          extensive=model_params['extensive'],
                          output_init=model_params['output_init']).to(device)
    else:
        raise ValueError(f'Invalid Model Name {model_name}')

    # Pretrained mode: load a checkpoint named after the target, test, and exit.
    if pretrain_params['flag']:
        torch_path = pretrain_params['path']
        target = model_params['targets'][0]
        model.load_state_dict(torch.load(f'{torch_path}/{target}.pt'))

        logger.info('Testing with Pretrained model')
        predictions, labels = evaluate(device, model, test_loader)
        test_mae = mean_absolute_error(labels, predictions)
        logger.info(f'Test MAE {test_mae:.4f}')

        return
    # define loss function and optimization
    loss_fn = nn.L1Loss()
    opt = optim.Adam(model.parameters(), lr=train_params['lr'], weight_decay=train_params['weight_decay'], amsgrad=True)
    scheduler = optim.lr_scheduler.StepLR(opt, train_params['step_size'], gamma=train_params['gamma'])

    # model training
    best_mae = 1e9
    no_improvement = 0
    
    # EMA for valid and test: a frozen shadow copy of the model, updated every
    # epoch, is what gets evaluated and snapshotted.
    logger.info('EMA Init')
    ema_model = copy.deepcopy(model)
    for p in ema_model.parameters():
        p.requires_grad_(False)
    
    best_model = copy.deepcopy(ema_model)

    logger.info('Training')
    for i in range(train_params['epochs']):
        train_loss = train(device, model, opt, loss_fn, train_loader)
        ema(ema_model, model, train_params['ema_decay'])
        # Validate (and consider early stopping) only every 'interval' epochs.
        if i % train_params['interval'] == 0:
            predictions, labels = evaluate(device, ema_model, valid_loader)

            valid_mae = mean_absolute_error(labels, predictions)
            logger.info(f'Epoch {i} | Train Loss {train_loss:.4f} | Val MAE {valid_mae:.4f}')

            if valid_mae > best_mae:
                no_improvement += 1
                if no_improvement == train_params['early_stopping']:
                    logger.info('Early stop.')
                    break
            else:
                no_improvement = 0
                best_mae = valid_mae
                best_model = copy.deepcopy(ema_model)
        else:
            logger.info(f'Epoch {i} | Train Loss {train_loss:.4f}')
        
        scheduler.step()

    # Final test uses the best EMA snapshot, not the last model state.
    logger.info('Testing')
    predictions, labels = evaluate(device, best_model, test_loader)
    test_mae = mean_absolute_error(labels, predictions)
    logger.info('Test MAE {:.4f}'.format(test_mae))

# Script entry point: click parses the CLI options and invokes main().
if __name__ == "__main__":
    main()
